import copy
import json
import math
import numpy
import random

import utils

class Grammar:
  """ Grammar whose non terminals are identified by their index.

  Every rule has one of two shapes: X -> a (stored in TP[X]) or X -> Y Z
  (stored in NTP[X] as the pair [Y, Z]). parse() reads its result for non
  terminal 0, i.e. non terminal 0 acts as the start symbol.
  """

  def __init__(self):
    self.n = 0 # Amount of terminal symbols, used to compute description length
    self.r = 0 # Amount of non terminal symbols (which are referred to by their index)
    self.T = [] # The set of terminal symbols
    self.TP = [] # TP[X]: list of terminals that non terminal X can derive
    self.NTP = [] # NTP[X]: list of [Y, Z] pairs of non terminals that X can derive

  def description_length(self):
    """ Returns description length of grammar. Depends on amount of rules and variables. """
    logr = utils.bits_required(self.r)
    logn = utils.bits_required(self.n)
    sum_elem = 0
    for x in self.TP:
      sum_elem += (len(x) + 1) * logn # +1 for segmenting the rules
    for x in self.NTP:
      sum_elem += len(x) * logr
    return sum_elem + logr # +logr because we need unary representation of amount of non terminals

  def copy(self):
    """ Returns a deep copy of the grammar (no rule list is shared with self) """
    g = Grammar()
    g.n = self.n
    g.r = self.r
    g.T = copy.deepcopy(self.T)
    g.TP = copy.deepcopy(self.TP)
    g.NTP = copy.deepcopy(self.NTP)
    return g

  # Serialization
  def to_json(self):
    """ Returns the instance attributes encoded as a JSON object """
    return json.dumps(self.__dict__)

  def get_neighbor(self):
    """ Returns a copy of the grammar with one random transformation applied.

    Transformations are drawn uniformly until one of them reports a change;
    self is never modified. The shared random generator is deliberately not
    reseeded here, so a caller that seeds random gets reproducible neighbors.
    """
    new_grammar = self.copy()
    transformations = [
      new_grammar.add_non_terminal,
      new_grammar.add_ntp,
      new_grammar.add_tp,
      new_grammar.remove_ntp,
      new_grammar.remove_tp,
      new_grammar.edit_ntp,
    ]
    while True:
      transformation = random.choice(transformations)
      is_changed = transformation()
      print(is_changed)
      if is_changed:
        return new_grammar

  def add_non_terminal(self):
    """ Add a non terminal with one random terminal rule and one random pair rule.

    The pair is drawn among the non terminals that existed before the call, so
    the new non terminal never refers to itself. Always returns True.
    """
    print('add nt')
    self.TP.append([random.choice(self.T)])
    self.NTP.append([[random.randrange(self.r), random.randrange(self.r)]])
    self.r += 1
    return True

  def add_ntp(self):
    """ Pick three non terminals X, Y, Z and add the rule X -> Y Z.

    Returns True if the rule was added, False if it already existed.
    """
    print('add ntp')
    X = random.randrange(self.r)
    Y = random.randrange(self.r)
    Z = random.randrange(self.r)
    if [Y, Z] not in self.NTP[X]:
      self.NTP[X].append([Y, Z])
      return True
    return False

  def add_tp(self):
    """ Pick a nonterminal and add a new terminal that can be derived from it.

    If nonterminal has all terminals already, or terminal picked is already
    there, does nothing. Returns True only if a terminal was added.
    """
    print('add tp')
    X = random.randrange(self.r)
    a = random.choice(self.T)
    if len(self.TP[X]) == len(self.T):
      return False
    if a not in self.TP[X]:
      self.TP[X].append(a)
      return True
    return False

  def remove_ntp(self):
    """ Picks a nonterminal and removes one of its non terminal production rules.

    Returns False if the picked nonterminal has no such rule.
    """
    print('remove ntp')
    X = random.randrange(self.r)
    if len(self.NTP[X]) > 0:
      self.NTP[X].pop(random.randrange(len(self.NTP[X])))
      return True
    return False

  def remove_tp(self):
    """ Pick a nonterminal and remove a terminal that can be derived from it.

    Returns False if the picked nonterminal derives no terminal.
    """
    print('remove tp')
    X = random.randrange(self.r)
    if len(self.TP[X]) > 0:
      self.TP[X].pop(random.randrange(len(self.TP[X])))
      return True
    return False

  def edit_ntp(self):
    """ Picks a random non terminal, then picks one of its production rules and changes one of its derivatives.

    The substitution may be the same nonterminal, which leaves the rule
    untouched; in that case (or when there is no rule to edit) returns False.
    """
    print('edit ntp')
    X = random.randrange(self.r)
    if len(self.NTP[X]) > 0:
      rule = random.choice(self.NTP[X])
      substitution = random.randrange(self.r)
      position = random.randrange(2)
      if rule[position] != substitution:
        rule[position] = substitution
        return True
    return False

  def parse(self, S):
    """ Returns the length of the shortest derivation of S from non terminal 0.

    Fills a table indexed by [start, span length, non terminal] bottom-up over
    span lengths. The length of a derivation is the sum, over every rule
    application in it, of utils.bits_required(number of rules of the non
    terminal being expanded).

    Returns False when S cannot be derived, which includes an empty S and a
    grammar without non terminals. NOTE(review): if utils.bits_required can
    return 0, a valid parse may have length 0, which is falsy too; callers
    should test the result with `is False`.
    """
    size = len(S)
    if size == 0 or self.r == 0:
      # No rule derives the empty sequence, and without non terminals there is
      # no start symbol to read the result from.
      return False

    table = numpy.zeros(shape=[size, size + 1, self.r], dtype=bool)
    # |backtrack| stores the shortest parse tree.
    # last coordinate in backtrack represents k, nonterminal left and nonterminal right
    # It is filled in below but not read anywhere in this method.
    backtrack = numpy.full([size, size + 1, self.r, 3], numpy.nan)
    # NaN marks "no derivation known yet" for a cell.
    lengths = numpy.full([size, size + 1, self.r], numpy.nan)

    for i, symbol in enumerate(S):
      for nonterminal in range(self.r):
        if symbol in self.TP[nonterminal]:
          table[i, 1, nonterminal] = True
          lengths[i, 1, nonterminal] = utils.bits_required(len(self.TP[nonterminal]) + len(self.NTP[nonterminal]))

    for i in range(2, size + 1): # Length of span
      for j in range(size - i + 1): # Start of span
        for k in range(1, i): # Partition of span
          for l in range(self.r): # Non terminal
            for pair in self.NTP[l]: # Rule
              if table[j, k, pair[0]] and table[j+k, i-k, pair[1]]:
                table[j, i, l] = True
                # The stored length includes the cost of the rule applied at
                # the root, so the candidate must include it as well before
                # the two are compared.
                candidate = lengths[j, k, pair[0]] + lengths[j+k, i-k, pair[1]] + utils.bits_required(len(self.TP[l]) + len(self.NTP[l]))
                if numpy.isnan(lengths[j, i, l]) or candidate < lengths[j, i, l]:
                  backtrack[j, i, l, :] = [k, pair[0], pair[1]]
                  lengths[j, i, l] = candidate

    parse_length = lengths[0, size, 0]
    return False if numpy.isnan(parse_length) else parse_length


def create_initial_cfg(terminals):
  """ Builds the grammar with a single non terminal X and the rules X -> X X and X -> a for each terminal 'a' in terminals.

  The grammar stores its own copies of |terminals|: T and TP[0] are separate
  lists, so editing the rules in place (add_tp, remove_tp) changes neither the
  terminal set nor the list passed in by the caller.
  """
  symbols = list(terminals)
  g = Grammar()
  g.n = len(symbols)
  g.r = 1
  g.T = symbols
  g.TP.append(list(symbols)) # distinct list object from g.T
  g.NTP.append([[0, 0]])
  return g

def from_json(encoded_grammar):
  """ Builds a Grammar from a JSON object by setting one attribute per key.

  This is the inverse of Grammar.to_json, which dumps the instance attributes.
  """
  g = Grammar()
  attrs = json.loads(encoded_grammar)
  for k in attrs:
    setattr(g, k, attrs[k])
  return g


def earley(grammar, S):
  """ Placeholder, presumably for an Earley parser of S with |grammar|; not implemented.

  Raises NotImplementedError so that a call cannot be mistaken for a failed
  parse (a bare `pass` would return None, which is falsy).
  """
  raise NotImplementedError('earley is not implemented')